/*
 * demo_GMR01.cpp
 *
 * Gaussian mixture model (GMM) with time-based Gaussian mixture regression
 * (GMR) used for reproduction.
 *
 * If this code is useful for your research, please cite the related publication:
 * @incollection{Calinon19MM,
 * 	 author="Calinon, S.",
 * 	 title="Mixture Models for the Analysis, Edition, and Synthesis of Continuous Time Series",
 * 	 booktitle="Mixture Models and Applications",
 * 	 publisher="Springer",
 * 	 editor="Bouguila, N. and Fan, W.", 
 * 	 year="2019",
 * 	 pages="39--57",
 * 	 doi="10.1007/978-3-030-23876-6_3"
 * }
 *
 * Authors: Sylvain Calinon, Philip Abbet
 *
 * This file is part of PbDlib, https://www.idiap.ch/software/pbdlib/
 * 
 * PbDlib is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 * 
 * PbDlib is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with PbDlib. If not, see <https://www.gnu.org/licenses/>.
 */


#include <stdio.h>
#include <armadillo>

#include <gfx2.h>
#include <gfx_ui.h>
#include <GLFW/glfw3.h>
#include <imgui.h>
#include <imgui_impl_glfw_gl2.h>

using namespace arma;


/***************************** ALGORITHM SECTION *****************************/

// Lists of Armadillo column vectors and matrices (trajectory points,
// demonstrations, component centers, covariance matrices, ...)
using vector_list_t = std::vector<vec>;
using matrix_list_t = std::vector<mat>;


//-----------------------------------------------------------------------------
// Parameters of the algorithm. Some of them can be modified through the UI,
// the others are hard-coded.
//-----------------------------------------------------------------------------
struct parameters_t {
	int nb_states;	// Number of components (Gaussians) in the GMM
	int nb_data;	// Number of datapoints in each (resampled) trajectory

	// Time step between two consecutive datapoints. Without rescaling, large
	// values (such as 1) have the advantage of creating clusters based on
	// position information.
	float dt;
};


//-----------------------------------------------------------------------------
// Gaussian mixture model, as produced by learn()
//-----------------------------------------------------------------------------
struct model_t {
	parameters_t parameters;	// Parameters used to train the model
	int nb_var;					// Number of variables [t, x1, x2]

	// One entry per GMM component
	vector_list_t mu;			// Centers (nb_var x 1 each)
	matrix_list_t sigma;		// Covariance matrices (nb_var x nb_var each)
	vec priors;					// Normalized mixing weights
};


//-----------------------------------------------------------------------------
// Likelihood of datapoint(s) to be generated by a Gaussian parameterized by
// center and covariance.
//
// Inputs:
//   - Data:  D x N array representing N datapoints of D dimensions.
//   - Mu:    D x 1 vector representing the center of the Gaussian.
//   - Sigma: D x D array representing the covariance matrix of the Gaussian.
//
// Output:
//   - N x 1 vector holding the likelihood of each of the N datapoints.
//-----------------------------------------------------------------------------
arma::vec gaussPDF(mat Data, colvec Mu, mat Sigma) {

	const int dims = Data.n_rows;
	const int count = Data.n_cols;

	// One centered datapoint per row (N x D)
	const mat centered = Data.t() - repmat(Mu.t(), count, 1);

	// Squared Mahalanobis distance of every datapoint to the center
	const vec mahalanobis = sum((centered * inv(Sigma)) % centered, 1);

	// DBL_MIN keeps the normalization strictly positive
	const double normalization = sqrt(pow((2 * datum::pi), dims) * det(Sigma) + DBL_MIN);

	return exp(-0.5 * mahalanobis) / normalization;
}

//-----------------------------------------------------------------------------
// Training of the model
//-----------------------------------------------------------------------------
void learn(const matrix_list_t& demos, model_t &model) {

	model.nb_var = 3;

	// Initialization of Gaussian Mixture Model (GMM) parameters by clustering 
	// the data into equal bins based on the first variable (time steps).
	const float diag_reg_fact = 1e-4f;

	vec timing_sep = linspace<vec>(
		demos[0](0, 0), demos[0](0, demos[0].n_cols - 1), model.parameters.nb_states + 1
	);

	model.mu.clear();
	model.sigma.clear();
	model.priors = vec(model.parameters.nb_states);

	mat data(model.nb_var, model.parameters.nb_data * demos.size());
	for (unsigned int m = 0; m < demos.size(); ++m) {
		data(span::all, span(m * model.parameters.nb_data,
							 (m + 1) * model.parameters.nb_data - 1)) =
			demos[m];
	}

	for (unsigned int i = 0; i < model.parameters.nb_states; ++i) {
		uvec idtmp = find( (data(0, span::all) >= timing_sep(i)) &&
						   (data(0, span::all) < timing_sep(i + 1)) );

		model.priors(i) = idtmp.size();
		model.mu.push_back(mean(data.cols(idtmp), 1));

		mat sigma = cov(data.cols(idtmp).t());

		// Optional regularization term to avoid numerical instability
		sigma = sigma + eye(model.nb_var, model.nb_var) * diag_reg_fact;

		model.sigma.push_back(sigma);
	}

	model.priors = model.priors / sum(model.priors);


	// Training of a Gaussian mixture model (GMM) with an expectation-maximization
	// (EM) algorithm
	const int nb_max_steps = 100;			// Maximum number of iterations allowed
	const int nb_min_steps = 5;				// Minimum number of iterations allowed
	const double max_diff_log_likelihood = 1e-4;	// Likelihood increase threshold
													// to stop the algorithm

	std::vector<double> log_likelihoods;

	for (int iter = 0; iter < nb_max_steps; ++iter) {

		// E-step
		mat L = ones(model.parameters.nb_states, data.n_cols);

		for (int i = 0; i < model.parameters.nb_states; ++i) {
			L(i, span::all) = model.priors(i) * mat(gaussPDF(data, model.mu[i], model.sigma[i])).t();
		}

		mat gamma = L / repmat(sum(L, 0) + DBL_MIN, model.parameters.nb_states, 1);
		mat gamma2 = gamma / repmat(sum(gamma, 1), 1, data.n_cols);


		// M-step
		for (int i = 0; i < model.parameters.nb_states; ++i) {

			// Update priors
			model.priors(i) = sum(gamma(i, span::all)) / data.n_cols;

			// Update mu
			model.mu[i] = data * gamma2(i, span::all).t();

			// Update sigma
			mat data_tmp = data - repmat(model.mu[i], 1, data.n_cols);
			model.sigma[i] = data * diagmat(gamma2(i, span::all)) * data_tmp.t() +
							 eye(data.n_rows, data.n_rows) * diag_reg_fact;
		}

		// Compute average log-likelihood
		log_likelihoods.push_back(vec(sum(log(sum(L, 0)), 1))[0] / data.n_cols);

		// Stop the algorithm if EM converged (small change of log-likelihood)
		if (iter >= nb_min_steps) {
			if (log_likelihoods[iter] - log_likelihoods[iter - 1] < max_diff_log_likelihood)
				break;
		}
	}
}


//-----------------------------------------------------------------------------
// Gaussian mixture regression (GMR)
//-----------------------------------------------------------------------------
void compute_GMR(const model_t& model, const vec& time_steps, mat &points,
				 matrix_list_t &sigma_out) {

	const int in = 0;
	const span out(1, model.nb_var - 1);

	const int nb_data = time_steps.n_elem;
	const int nb_var_out = out.b - out.a + 1;

	const float diag_reg_fact = 1e-8f;

	mat mu_tmp = zeros(nb_var_out, model.parameters.nb_states);
	points = zeros(nb_var_out, nb_data);
	sigma_out.clear();

	mat H(model.parameters.nb_states, nb_data);

	for (int t = 0; t < nb_data; ++t) {

		// Compute activation weight
		for (int i = 0; i < model.parameters.nb_states; ++i) {
			mat time_step(1, 1);
			time_step(0, 0) = time_steps(t);

			vec mu(1);
			mu(0) = model.mu[i](in);

			mat sigma(1, 1);
			sigma(0, 0) = model.sigma[i](in, in);

			H(i, t) = model.priors(i) * gaussPDF(time_step, mu, sigma)[0];
		}

		H(span::all, t) = H(span::all, t) / sum(H(span::all, t) + DBL_MIN);

		// Compute conditional means
		for (int i = 0; i < model.parameters.nb_states; ++i) {
			mu_tmp(span::all, i) = model.mu[i](out) +
								   model.sigma[i](out, in) / model.sigma[i](in, in) *
									   (time_steps(t) * model.mu[i](in));

			points(span::all, t) = points(span::all, t) + H(i, t) * mu_tmp(span::all, i);
		}

		// Compute conditional covariances
		mat sigma = zeros(nb_var_out, nb_var_out);

		for (int i = 0; i < model.parameters.nb_states; ++i) {
			mat sigma_tmp = model.sigma[i](out, out) -
							model.sigma[i](out, in) / model.sigma[i](in, in) *
								model.sigma[i](in, out);

			sigma = sigma + H(i, t) * (sigma_tmp + mu_tmp(span::all, i) * mu_tmp(span::all, i).t());
		}

		sigma = sigma - points(span::all, t) * points(span::all, t).t() +
				eye(nb_var_out, nb_var_out) * diag_reg_fact;

		sigma_out.push_back(sigma);
	}
}


/****************************** HELPER FUNCTIONS *****************************/

//-----------------------------------------------------------------------------
// GLFW error callback: reports the error code and its description on stderr
//-----------------------------------------------------------------------------
static void error_callback(int error, const char* description) {
	const char* const message_format = "Error %d: %s\n";
	fprintf(stderr, message_format, error, description);
}


//-----------------------------------------------------------------------------
// Contains all the information about a viewport: its rectangle (as passed to
// glViewport/glScissor) and the parameters of its orthographic projection
//-----------------------------------------------------------------------------
struct viewport_t {
	// Rectangle of the viewport
	int x, y;
	int width, height;

	// Projection matrix parameters (see setup_viewport)
	arma::vec projection_top_left;
	arma::vec projection_bottom_right;
	double projection_near, projection_far;
};


//-----------------------------------------------------------------------------
// Helper function to setup a viewport
//
// Stores the rectangle of the viewport and an orthographic projection centered
// on it: x spans [-width/2, width/2] and y spans [-height/2, height/2] (top
// edge positive).
//-----------------------------------------------------------------------------
void setup_viewport(viewport_t* viewport, int x, int y, int width, int height,
					double near_distance = -1.0, double far_distance = 1.0) {

	const double half_width = (double) width / 2;
	const double half_height = (double) height / 2;

	viewport_t& target = *viewport;

	target.x = x;
	target.y = y;
	target.width = width;
	target.height = height;

	target.projection_top_left = vec({ -half_width, half_height });
	target.projection_bottom_right = vec({ half_width, -half_height });
	target.projection_near = near_distance;
	target.projection_far = far_distance;
}


//-----------------------------------------------------------------------------
// Converts some coordinates from UI-space to OpenGL-space, taking the
// coordinates of a viewport into account
//
// The UI y axis points downwards; the result is expressed relative to the
// center of the viewport, with the y axis pointing upwards.
//-----------------------------------------------------------------------------
arma::vec ui2fb(const arma::vec& coords, const gfx2::window_size_t& window_size,
				const viewport_t& viewport) {

	// UI -> pixels relative to the viewport origin
	const double local_x = coords(0) * static_cast<float>(window_size.fb_width)
						   / static_cast<float>(window_size.win_width) - viewport.x;

	const double local_y = (window_size.win_height - coords(1)) *
						   static_cast<float>(window_size.fb_height)
						   / static_cast<float>(window_size.win_height) - viewport.y;

	// Viewport origin -> viewport center
	arma::vec fb_coords = coords;
	fb_coords(0) = local_x - static_cast<float>(viewport.width) * 0.5f;
	fb_coords(1) = local_y - static_cast<float>(viewport.height) * 0.5f;

	return fb_coords;
}


//-----------------------------------------------------------------------------
// Converts some coordinates from OpenGL-space to UI-space, taking the
// coordinates of a viewport into account (inverse of ui2fb)
//-----------------------------------------------------------------------------
arma::vec fb2ui(const arma::vec& coords, const gfx2::window_size_t& window_size,
				const viewport_t& viewport) {

	// Viewport center -> viewport origin
	const double local_x = coords(0) + static_cast<float>(viewport.width) * 0.5f;
	const double local_y = coords(1) + static_cast<float>(viewport.height) * 0.5f;

	// Viewport origin -> UI (y axis pointing downwards)
	arma::vec ui_coords = coords;

	ui_coords(0) = (local_x + viewport.x) * static_cast<float>(window_size.win_width)
				   / static_cast<float>(window_size.fb_width);

	ui_coords(1) = window_size.win_height -
				   (local_y + viewport.y) * static_cast<float>(window_size.win_height)
				   / static_cast<float>(window_size.fb_height);

	return ui_coords;
}


//-----------------------------------------------------------------------------
// Colors of the displayed lines and gaussians: 7 rows, each one an RGB
// triplet with components in [0, 1]. Any index used to select a row must stay
// below COLORS.n_rows.
//-----------------------------------------------------------------------------
const mat COLORS({
	{ 0.0,  0.0,  1.0  },
	{ 0.0,  0.5,  0.0  },
	{ 1.0,  0.0,  0.0  },
	{ 0.0,  0.75, 0.75 },
	{ 0.75, 0.0,  0.75 },
	{ 0.75, 0.75, 0.0  },
	{ 0.25, 0.25, 0.25 },
});


//-----------------------------------------------------------------------------
// Create a demonstration (with a length of 'time_steps.size()') from a
// trajectory (of any length)
//
// The x and y coordinates of the trajectory are linearly resampled over the
// point indices. The returned matrix is 3 x time_steps.size(), each column
// being [t; x; y]. The trajectory is expected to hold at least 2 points.
//-----------------------------------------------------------------------------
mat sample_trajectory(const vector_list_t& trajectory, const vec& time_steps) {

	const uword nb_points = trajectory.size();
	const uword nb_samples = time_steps.size();

	// Split the points into their x and y coordinates
	vec xs(nb_points);
	vec ys(nb_points);

	for (uword k = 0; k < nb_points; ++k) {
		xs(k) = trajectory[k](0);
		ys(k) = trajectory[k](1);
	}

	// Resample both coordinates at evenly spaced positions along the indices
	const vec source_positions = linspace<vec>(0, nb_points - 1, nb_points);
	const vec target_positions = linspace<vec>(0, nb_points - 1, nb_samples);

	vec xs_resampled;
	vec ys_resampled;
	interp1(source_positions, xs, target_positions, xs_resampled, "*linear");
	interp1(source_positions, ys, target_positions, ys_resampled, "*linear");

	// Stack time, x and y row by row
	mat demo(3, nb_samples);
	demo.row(0) = time_steps.t();
	demo.row(1) = xs_resampled.t();
	demo.row(2) = ys_resampled.t();

	return demo;
}


//-----------------------------------------------------------------------------
// State of the application: values of the parameters modifiable via the UI
// and what the user is currently doing
//-----------------------------------------------------------------------------
struct gui_state_t {
	bool is_drawing_demonstration;			// A new demonstration is being drawn
	bool is_parameters_dialog_displayed;	// The parameters dialog is shown
	bool are_parameters_modified;			// Parameters changed through the UI
	bool must_recompute_GMR;				// Reproductions must be recomputed

	// UI-editable copies of the corresponding fields of parameters_t
	int parameter_nb_states;
	int parameter_nb_data;
};


//-----------------------------------------------------------------------------
// Render the "demonstrations & model" viewport
//
// Draws, in this order: the spatial part [x1, x2] of each GMM state, the
// trajectory currently being drawn by the user (if any) and the recorded
// demonstrations. Colors cycle through the rows of COLORS.
//-----------------------------------------------------------------------------
void draw_demos_viewport(const viewport_t& viewport,
						 const vector_list_t& current_trajectory,
						 const matrix_list_t& demonstrations,
						 const model_t& model) {

	glViewport(viewport.x, viewport.y, viewport.width, viewport.height);
	glScissor(viewport.x, viewport.y, viewport.width, viewport.height);
	glClearColor(0.7f, 0.7f, 0.7f, 0.0f);
	glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	glMatrixMode(GL_PROJECTION);
	glLoadIdentity();
	glOrtho(viewport.projection_top_left(0), viewport.projection_bottom_right(0),
			viewport.projection_bottom_right(1), viewport.projection_top_left(1),
			viewport.projection_near, viewport.projection_far);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();

	// Draw the GMM states
	if (!model.mu.empty()) {
		for (size_t i = 0; i < model.mu.size(); ++i) {
			glClear(GL_DEPTH_BUFFER_BIT);

			// Wrap around the actual number of colors: there can be more
			// states than rows in COLORS
			gfx2::draw_gaussian(conv_to<fvec>::from(COLORS.row(i % COLORS.n_rows).t()),
								model.mu[i](span(1, model.nb_var - 1), span::all),
								model.sigma[i](span(1, model.nb_var - 1),
											   span(1, model.nb_var - 1))
			);
		}

		glClear(GL_DEPTH_BUFFER_BIT);
	}

	// Draw the currently created demonstration (if any)
	if (current_trajectory.size() > 1)
		gfx2::draw_line(arma::fvec({0.33f, 0.97f, 0.33f}), current_trajectory);

	// Draw the demonstrations (spatial rows x1, x2 only)
	for (size_t i = 0; i < demonstrations.size(); ++i) {
		arma::mat datapoints = demonstrations[i](span(1, 2), span::all);

		arma::fvec color = arma::conv_to<arma::fvec>::from(COLORS.row(i % COLORS.n_rows));

		gfx2::draw_line(color, datapoints);
	}
}


//-----------------------------------------------------------------------------
// Render a "reproduction" viewport
//-----------------------------------------------------------------------------
void draw_GMR_viewport(const viewport_t& viewport, const mat& points,
					   const std::vector<gfx2::model_t>& models) {

	glViewport(viewport.x, viewport.y, viewport.width, viewport.height);
	glScissor(viewport.x, viewport.y, viewport.width, viewport.height);
	glClearColor(0.9f, 0.9f, 0.9f, 0.0f);
	glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	glMatrixMode(GL_PROJECTION);
	glLoadIdentity();
	glOrtho(viewport.projection_top_left(0), viewport.projection_bottom_right(0),
			viewport.projection_bottom_right(1), viewport.projection_top_left(1),
			viewport.projection_near, viewport.projection_far);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();

	if (!models.empty()) {
		for (int i = 0; i < models.size(); ++i) {
			glClear(GL_DEPTH_BUFFER_BIT);
			gfx2::draw(models[i]);
		}

		glClear(GL_DEPTH_BUFFER_BIT);

		glLineWidth(4.0f);
		gfx2::draw_line(arma::fvec({0.0f, 0.4f, 0.0f}), points);
		glLineWidth(1.0f);
	}
}


//-----------------------------------------------------------------------------
// Returns the dimensions (width, height) that a plot should have inside the
// provided viewport, leaving a fixed margin on every side
//-----------------------------------------------------------------------------
ivec get_plot_dimensions(const viewport_t& viewport) {

	const int margin = 50;

	ivec dimensions(2);
	dimensions(0) = viewport.width - margin * 2;
	dimensions(1) = viewport.height - margin * 2;

	return dimensions;
}


//-----------------------------------------------------------------------------
// Render a "timeline" viewport
//
// Plots variable 'dimension' (1 for x1, 2 for x2) against time, including:
//   - the GMR mean and a +/- one standard deviation band,
//   - the GMM states projected on (t, x_dimension),
//   - the demonstrations.
// Time is mapped to the plot width using the last time step of the first
// demonstration, and values are scaled by plot height / viewport height.
//-----------------------------------------------------------------------------
void draw_timeline_viewport(const gfx2::window_size_t& window_size,
							const viewport_t& viewport,
							const matrix_list_t& demonstrations,
							const model_t& model,
							const mat& GMR_points, const matrix_list_t& GMR_sigma,
							unsigned int dimension) {

	glViewport(viewport.x, viewport.y, viewport.width, viewport.height);
	glScissor(viewport.x, viewport.y, viewport.width, viewport.height);
	glClearColor(0.9f, 0.9f, 0.9f, 0.0f);
	glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

	glMatrixMode(GL_PROJECTION);
	glLoadIdentity();
	glOrtho(viewport.projection_top_left(0), viewport.projection_bottom_right(0),
			viewport.projection_bottom_right(1), viewport.projection_top_left(1),
			viewport.projection_near, viewport.projection_far);

	glMatrixMode(GL_MODELVIEW);
	glLoadIdentity();

	ivec plot_dimensions = get_plot_dimensions(viewport);

	ivec plot_top_left({ -plot_dimensions(0) / 2, plot_dimensions(1) / 2 });
	ivec plot_bottom_right({ plot_dimensions(0) / 2, -plot_dimensions(1) / 2 });

	// Axis labels
	ui::begin("Text");

	vec coords = fb2ui(vec({ -20.0, double(-viewport.height / 2 + 45) }),
					   window_size, viewport);
	ui::text(ImVec2(coords(0), coords(1)), "t", ImVec4(0,0,0,1));

	std::stringstream label;
	label << "x" << dimension;

	coords = fb2ui(vec({ double(-viewport.width / 2) + 10, -20.0 }),
				   window_size, viewport);
	ui::text(ImVec2(coords(0), coords(1)), label.str(), ImVec4(0,0,0,1));

	ui::end();

	// Draw the axes
	gfx2::draw_line(fvec({0.0f, 0.0f, 0.0f}),
					mat({ { double(plot_top_left(0)), double(plot_bottom_right(0)) },
						  { double(plot_bottom_right(1)), double(plot_bottom_right(1)) }
						})
	);

	gfx2::draw_line(fvec({0.0f, 0.0f, 0.0f}),
					mat({ { double(plot_top_left(0)), double(plot_top_left(0)) },
						  { double(plot_bottom_right(1)), double(plot_top_left(1)) }
						})
	);

	// Check if there is something to display (the GMR band needs at least two
	// GMR points; 'n_cols - 1' below would underflow otherwise)
	if (demonstrations.empty() || (GMR_points.n_cols < 2))
		return;

	// Draw the GMR
	const double scale_x = (double) plot_dimensions(0) / demonstrations[0](0, demonstrations[0].n_cols - 1);
	const double scale_y = (double) plot_dimensions(1) / viewport.height;

	const uword nb_points = GMR_points.n_cols;

	mat top_vertices(2, nb_points);
	mat bottom_vertices(2, nb_points);

	for (uword j = 0; j < nb_points; ++j) {
		const double std_dev = sqrt(GMR_sigma[j](dimension - 1, dimension - 1));

		top_vertices(0, j) = demonstrations[0](0, j) * scale_x - plot_dimensions(0) / 2;
		top_vertices(1, j) = (GMR_points(dimension - 1, j) + std_dev) * scale_y;

		bottom_vertices(0, j) = top_vertices(0, j);
		bottom_vertices(1, j) = (GMR_points(dimension - 1, j) - std_dev) * scale_y;
	}

	// Two triangles per segment between consecutive time steps
	mat gmr_points(2, (nb_points - 1) * 6);

	for (uword j = 0; j + 1 < nb_points; ++j) {
		gmr_points(span::all, j * 6 + 0) = top_vertices(span::all, j);
		gmr_points(span::all, j * 6 + 1) = bottom_vertices(span::all, j);
		gmr_points(span::all, j * 6 + 2) = top_vertices(span::all, j + 1);

		gmr_points(span::all, j * 6 + 3) = top_vertices(span::all, j + 1);
		gmr_points(span::all, j * 6 + 4) = bottom_vertices(span::all, j);
		gmr_points(span::all, j * 6 + 5) = bottom_vertices(span::all, j + 1);
	}

	gfx2::model_t gmr_model = gfx2::create_mesh(fvec({ 0.0f, 0.8f, 0.0f, 0.05f }), gmr_points);
	gmr_model.use_one_minus_src_alpha_blending = true;
	gfx2::draw(gmr_model);

	// The mesh is rebuilt at every frame: release it right away, otherwise a
	// new one is leaked each time the viewport is drawn
	gfx2::destroy(gmr_model);

	// Draw the GMM states
	for (size_t i = 0; i < model.mu.size(); ++i) {
		glClear(GL_DEPTH_BUFFER_BIT);

		vec mu = model.mu[i].elem(uvec({ 0, dimension }));
		mu(0) = mu(0) * scale_x - plot_dimensions(0) / 2;
		mu(1) *= scale_y;

		// The plot maps (t, x) to (t * scale_x - offset, x * scale_y), so the
		// covariance becomes A * Sigma * A^T with A = diag(scale_x, scale_y)
		mat sigma = model.sigma[i].submat(uvec({ 0, dimension }), uvec({ 0, dimension }));
		sigma(0, 0) = sigma(0, 0) * (scale_x * scale_x);
		sigma(0, 1) = sigma(0, 1) * (scale_x * scale_y);
		sigma(1, 0) = sigma(1, 0) * (scale_x * scale_y);
		sigma(1, 1) = sigma(1, 1) * (scale_y * scale_y);

		// Wrap around the actual number of colors (more states than rows)
		gfx2::draw_gaussian(conv_to<fvec>::from(COLORS.row(i % COLORS.n_rows).t()), mu, sigma,
							true, false);
	}

	glClear(GL_DEPTH_BUFFER_BIT);

	// Draw the demonstrations
	for (size_t i = 0; i < demonstrations.size(); ++i) {
		arma::mat datapoints = demonstrations[i].rows(uvec({ 0, dimension }));

		datapoints(0, span::all) = datapoints(0, span::all) * scale_x - plot_dimensions(0) / 2;
		datapoints(1, span::all) *= scale_y;

		arma::fvec color = arma::conv_to<arma::fvec>::from(COLORS.row(i % COLORS.n_rows));

		gfx2::draw_line(color, datapoints);
	}

	// Draw the GMR result
	mat points(2, nb_points);
	points(0, span::all) = demonstrations[0](0, span::all);
	points(1, span::all) = GMR_points(dimension - 1, span::all);

	points(0, span::all) = points(0, span::all) * scale_x - plot_dimensions(0) / 2;
	points(1, span::all) *= scale_y;

	glLineWidth(4.0f);
	gfx2::draw_line(arma::fvec({0.0f, 0.4f, 0.0f}), points);
	glLineWidth(1.0f);
}


/******************************* MAIN FUNCTION *******************************/

int main(int argc, char **argv) {
	arma_rng::set_seed_random();

	// Model
	model_t model;

	// Parameters
	model.parameters.nb_states = 6;
	model.parameters.nb_data   = 200;
	model.parameters.dt		   = 0.001f;


	// Take 4k screens into account (framebuffer size != window size)
	gfx2::window_size_t window_size;
	window_size.win_width = 800;
	window_size.win_height = 800;
	window_size.fb_width = -1;	// Will be known later
	window_size.fb_height = -1;
	int viewport_width = 0;
	int viewport_height = 0;


	// Initialise GLFW
	glfwSetErrorCallback(error_callback);

	if (!glfwInit())
		return -1;

	glfwWindowHint(GLFW_SAMPLES, 4);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 2);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1);

	// Open a window and create its OpenGL context
	GLFWwindow* window = gfx2::create_window_at_optimal_size(
		"Demo Conditioning on trajectory distributions",
		window_size.win_width, window_size.win_height
	);

	glfwMakeContextCurrent(window);


	// Setup OpenGL
	gfx2::init();
	glEnable(GL_SCISSOR_TEST);
	glEnable(GL_DEPTH_TEST);
	glEnable(GL_CULL_FACE);
	glEnable(GL_LINE_SMOOTH);
	glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);

	// Setup ImGui
	ImGui::CreateContext();
	ImGui_ImplGlfwGL2_Init(window, true);


	// Viewports
	viewport_t viewport_demos;
	viewport_t viewport_GMR;
	viewport_t viewport_x1;
	viewport_t viewport_x2;


	// GUI state
	gui_state_t gui_state;
	gui_state.is_drawing_demonstration = false;
	gui_state.is_parameters_dialog_displayed = false;
	gui_state.are_parameters_modified = true;
	gui_state.must_recompute_GMR = false;
	gui_state.parameter_nb_states = model.parameters.nb_states;
	gui_state.parameter_nb_data = model.parameters.nb_data;


	// List of demonstrations and GMr results
	vec time_steps;
	matrix_list_t demos;
	mat GMR_points;
	matrix_list_t GMR_sigma;
	std::vector<gfx2::model_t> GMR_models;


	// Main loop
	vector_list_t current_trajectory;
	std::vector<vector_list_t> original_trajectories;

	while (!glfwWindowShouldClose(window)) {
		glfwPollEvents();

		// Handling of the resizing of the window
		gfx2::window_result_t window_result =
			gfx2::handle_window_resizing(window, &window_size);

		if (window_result == gfx2::INVALID_SIZE)
			continue;

		if ((window_result == gfx2::WINDOW_READY) || (window_result == gfx2::WINDOW_RESIZED)) {

			viewport_width = window_size.fb_width / 2 - 1;
			viewport_height = window_size.fb_height / 2 - 1;

			// Update all the viewports
			setup_viewport(&viewport_demos, 0, window_size.fb_height - viewport_height,
						   viewport_width, viewport_height);

			setup_viewport(&viewport_GMR, window_size.fb_width - viewport_width,
						   window_size.fb_height - viewport_height,
						   viewport_width, viewport_height);

			setup_viewport(&viewport_x1, 0, 0, viewport_width, viewport_height);

			setup_viewport(&viewport_x2, window_size.fb_width - viewport_width, 0,
						   viewport_width, viewport_height);
		}


		// If the parameters changed, learn the model again
		if (gui_state.are_parameters_modified) {

			if (time_steps.size() != gui_state.parameter_nb_data) {
				demos.clear();

				time_steps = linspace<vec>(
					0, gui_state.parameter_nb_data - 1, gui_state.parameter_nb_data
				) * model.parameters.dt;

				for (size_t i = 0; i < original_trajectories.size(); ++i) {
					demos.push_back(sample_trajectory(original_trajectories[i],
													  time_steps)
					);
				}
			}

			model.parameters.nb_data = gui_state.parameter_nb_data;
			model.parameters.nb_states = gui_state.parameter_nb_states;

			if (!demos.empty()) {
				learn(demos, model);
				gui_state.must_recompute_GMR = true;
			}

			gui_state.are_parameters_modified = false;
		}


		// Recompute the GMR (if necessary)
		if (gui_state.must_recompute_GMR) {

			compute_GMR(model, time_steps, GMR_points, GMR_sigma);

			// Create one big mesh for the GMR viewport (for performance reasons)
			for (int i = 0; i < GMR_models.size(); ++i)
				gfx2::destroy(GMR_models[i]);
			
			GMR_models.clear();

			const int NB_POINTS = 60;

			mat vertices(2, NB_POINTS * 3 * GMR_sigma.size());
			mat lines(2, NB_POINTS * 2 * GMR_sigma.size());

			for (int j = 0; j < GMR_sigma.size(); ++j) {

				mat v = gfx2::get_gaussian_background_vertices(GMR_points(span::all, j),
															   GMR_sigma[j], NB_POINTS);

				vertices(span::all, span(j * NB_POINTS * 3, (j + 1) * NB_POINTS * 3 - 1)) = v;

				mat p = gfx2::get_gaussian_border_vertices(GMR_points(span::all, j),
														   GMR_sigma[j], NB_POINTS, false);

				lines(span::all, span(j * NB_POINTS * 2, (j + 1) * NB_POINTS * 2 - 1)) = p;
			}

			GMR_models.push_back(
				gfx2::create_mesh(fvec({ 0.0f, 0.8f, 0.0f, 0.1f }), vertices)
			);

			GMR_models[0].use_one_minus_src_alpha_blending = true;

			GMR_models.push_back(
				gfx2::create_line(fvec({ 0.0f, 0.4f, 0.0f, 0.1f }), lines,
								  arma::zeros<arma::fvec>(3),
								  arma::eye<arma::fmat>(4, 4), 0, false)
			);

			gui_state.must_recompute_GMR = false;
		}


		// Start the rendering
		ImGui_ImplGlfwGL2_NewFrame();

		glViewport(0, 0, window_size.fb_width, window_size.fb_height);
		glScissor(0, 0, window_size.fb_width, window_size.fb_height);
		glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
		glClear(GL_COLOR_BUFFER_BIT);

		draw_demos_viewport(viewport_demos, current_trajectory, demos, model);

		draw_GMR_viewport(viewport_GMR, GMR_points, GMR_models);

		draw_timeline_viewport(window_size, viewport_x1, demos, model, GMR_points, GMR_sigma, 1);

		draw_timeline_viewport(window_size, viewport_x2, demos, model, GMR_points, GMR_sigma, 2);


		// Window: Demonstrations & model
		ImGui::SetNextWindowSize(ImVec2(window_size.win_width / 2, 36));
		ImGui::SetNextWindowPos(ImVec2(0, 0));
		ImGui::Begin("Demonstrations & model", NULL,
					 ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
					 ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
					 ImGuiWindowFlags_NoTitleBar
		);

		ImGui::Text("Demonstrations & model		");
		ImGui::SameLine();

		if (ImGui::Button("Clear")) {
			original_trajectories.clear();
			demos.clear();
			GMR_points = mat();
			GMR_sigma.clear();
			GMR_models.clear();
			model.mu.clear();
			model.sigma.clear();
		}

		ImGui::SameLine();
		ImGui::Text("	  ");
		ImGui::SameLine();

		if (ImGui::Button("Parameters"))
			gui_state.is_parameters_dialog_displayed = true;

		ImGui::End();


		// Window: GMR
		ImGui::SetNextWindowSize(ImVec2(window_size.win_width / 2, 36));
		ImGui::SetNextWindowPos(ImVec2(window_size.win_width - window_size.win_width / 2, 0));
		ImGui::Begin("GMR", NULL,
					 ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
					 ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
					 ImGuiWindowFlags_NoTitleBar
		);

		ImGui::Text("GMR");

		ImGui::End();


		// Window: Timeline x1
		ImGui::SetNextWindowSize(ImVec2(window_size.win_width / 2, 36));
		ImGui::SetNextWindowPos(ImVec2(0, window_size.win_height / 2));
		ImGui::Begin("Timeline: x1", NULL,
					 ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
					 ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
					 ImGuiWindowFlags_NoTitleBar
		);

		ImGui::Text("Timeline: x1");

		ImGui::End();


		// Window: Timeline x2
		ImGui::SetNextWindowSize(ImVec2(window_size.win_width / 2, 36));
		ImGui::SetNextWindowPos(ImVec2(window_size.win_width - window_size.win_width / 2,
									   window_size.win_height / 2));
		ImGui::Begin("Timeline: x2", NULL,
					 ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
					 ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse |
					 ImGuiWindowFlags_NoTitleBar
		);

		ImGui::Text("Timeline: x2");

		ImGui::End();


		// Window: Parameters
		ImGui::SetNextWindowSize(ImVec2(440, 106));
		ImGui::SetNextWindowPos(ImVec2((window_size.win_width - 440) / 2, (window_size.win_height - 106) / 2));
		ImGui::PushStyleColor(ImGuiCol_WindowBg, ImVec4(0, 0, 0, 255));

		if (gui_state.is_parameters_dialog_displayed)
			ImGui::OpenPopup("Parameters");

		if (ImGui::BeginPopupModal("Parameters", NULL,
								   ImGuiWindowFlags_NoResize |
								   ImGuiWindowFlags_NoSavedSettings)) {

			ImGui::SliderInt("Nb states", &gui_state.parameter_nb_states, 2, 20);
			ImGui::SliderInt("Nb data", &gui_state.parameter_nb_data, 100, 300);

			if (ImGui::Button("Close")) {
				ImGui::CloseCurrentPopup();
				gui_state.is_parameters_dialog_displayed = false;
				gui_state.are_parameters_modified = true;
			}

			ImGui::EndPopup();
		}

		ImGui::PopStyleColor();


		// GUI rendering
		ImGui::Render();
		ImGui_ImplGlfwGL2_RenderDrawData(ImGui::GetDrawData());

		// Swap buffers
		glfwSwapBuffers(window);

		// Keyboard input
		if (ImGui::IsKeyPressed(GLFW_KEY_ESCAPE))
			break;


		if (!gui_state.is_drawing_demonstration && !gui_state.is_parameters_dialog_displayed) {
			// Left click: start a new demonstration (only if not on the UI and in the
			// demonstrations viewport)
			if (ImGui::IsMouseClicked(GLFW_MOUSE_BUTTON_1)) {
				double mouse_x, mouse_y;
				glfwGetCursorPos(window, &mouse_x, &mouse_y);

				if ((mouse_x <= window_size.win_width / 2) &&
					(mouse_y > 36) && (mouse_y <= window_size.win_height / 2))
				{
					gui_state.is_drawing_demonstration = true;

					vec coords = ui2fb({ mouse_x, mouse_y }, window_size, viewport_demos);
					current_trajectory.push_back(coords);
				}
			}
		} else if (gui_state.is_drawing_demonstration) {
			double mouse_x, mouse_y;
			glfwGetCursorPos(window, &mouse_x, &mouse_y);

			vec coords = ui2fb({ mouse_x, mouse_y }, window_size, viewport_demos);

			vec last_point = current_trajectory[current_trajectory.size() - 1];
			vec diff = abs(coords - last_point);

			if ((diff(0) > 1e-6) && (diff(1) > 1e-6))
				current_trajectory.push_back(coords);

			// Left mouse button release: end the demonstration creation
			if (!ImGui::IsMouseDown(GLFW_MOUSE_BUTTON_1)) {
				gui_state.is_drawing_demonstration = false;

				if (current_trajectory.size() > 1) {
					demos.push_back(sample_trajectory(current_trajectory, time_steps));

					original_trajectories.push_back(current_trajectory);

					learn(demos, model);

					gui_state.must_recompute_GMR = true;
				}

				current_trajectory.clear();
			}
		}
	}


	// Cleanup
	ImGui_ImplGlfwGL2_Shutdown();
	glfwTerminate();

	return 0;
}
